go-kratos unit testing

Mar 25, 2025 • 10 min read

package biz
 
import (
	"context"
	"io"
	"reflect"
	"testing"
 
	"github.com/go-kratos/kratos/v2/log"
)
 
// 想测 GreeterUsecase 的 CreateGreeter 方法,就按下面这个格式写。
func TestGreeterUsecase_CreateGreeter(t *testing.T) {
	// testCases 定义了所有测试用例。
	// 可能这个写法看起来有点怪,但实际上就是个元素是匿名结构体的切片。
	testCases := []struct {
		name   string                 // 测试名称
		create func() *GreeterUsecase // 构造函数
		ctx    context.Context        // 参数 1
		arg    *Greeter               // 参数 2
		want   *Greeter               // 期望响应
		err    error                  // 期望错误
	}{
		{
			name: "normal",
			create: func() *GreeterUsecase {
				// 正常情况下怎么创建 use case 对象,测试时就怎么创建。
				// 因为你测试的是 biz 而不是 repo。所以应该自己创建 stub repo。
				// 如果想测执行了什么 sql 的话,去 data 层测 repo 去。
				//
				// 可能你会担心 stub 和现实不一样导致测不出 bug 来。
				// 那你就应该在 GreeterRepo 接口的文档注释中写明白
				// 传什么参数返什么结果;有什么问题报什么错。让 stub 和实际业务保持一致。
				//
				// 即使实际就是和接口的文档注释不一致,真出了 bug。
				// 那也应该是让 repo 实现和文档一致,和 biz 没有关系,不是么?
				//
				// 比如 GreeterRepo 的接口文档可以这么写:
				//
				//	// GreeterRepo 接口定义了用户数据访问接口。
				//	type GreeterRepo interface {
				//		// Save 方法将用户保存到数据库中,并返回保存后的用户。
				//		// 如果……,则返回 ErrUserNotFound。
				//		Save(context.Context, *Greeter) (*Greeter, error)
				//	}
				var stub = &stubGreeterRepo{
					saveFn: func(ctx context.Context, a *Greeter) (*Greeter, error) {
						return a, nil
					},
				}
 
				return NewGreeterUsecase(
					stub,
					log.NewStdLogger(io.Discard),
				)
			},
			ctx: context.Background(),
			arg: &Greeter{
				Hello: "foo",
			},
			want: &Greeter{
				Hello: "foo",
			},
			err: nil,
		},
		{
			name: "not found",
			create: func() *GreeterUsecase {
				return NewGreeterUsecase(
					&stubGreeterRepo{
						saveFn: func(ctx context.Context, a *Greeter) (*Greeter, error) {
							return nil, ErrUserNotFound
						},
					},
					log.NewStdLogger(io.Discard),
				)
			},
			ctx: context.Background(),
			arg: &Greeter{
				Hello: "not found",
			},
			want: nil,
			err:  ErrUserNotFound,
		},
		// 当线上业务出 bug,找到原因之后,就把对应的参数和返回值添加到测试里。
		// 以后就肯定不会再出问题了。
	}
	// for 循环测试用例
	for _, tc := range testCases {
		// 为每个测试用例执行一个子测试
		t.Run(tc.name, func(t *testing.T) {
			// 构造对象,调用方法,传入期望的参数。
			uc := tc.create()
			res, err := uc.CreateGreeter(tc.ctx, tc.arg)
 
			// 对比返回的结果是否与期望一致。
			if err != tc.err {
				t.Fatalf("err != tc.err. want %v, got %v", tc.err, err)
			}
			// 数值类型及其结构体是可以直接用 ”==“ 判等的,但切片、map 等类型或包含它们的类型则不能如此。
			// 这些类型可以使用 reflect.DeepEqual() 进行判等操作。
			if !reflect.DeepEqual(res, tc.want) {
				t.Fatalf("res != tc.want. want %v, got %v", tc.want, res)
			}
		})
	}
}
 
// stubGreeterRepo 实现了 GreeterRepo 接口,用于进行测试。
// 如果不传 saveFn,那么调用 Save 方法就会 panic。不必担心,就应该是这样。
// 请记住写测试就是为了发现问题的。为什么项目之前的提交不报错,到你这儿就 panic 了?
// 比起研究“黑魔法”去解决 panic,不如去查查业务逻辑到底改了什么才导致 panic 的。
type stubGreeterRepo struct {
	saveFn        func(context.Context, *Greeter) (*Greeter, error)
	updateFn      func(context.Context, *Greeter) (*Greeter, error)
	findByIDFn    func(context.Context, int64) (*Greeter, error)
	listByHelloFn func(context.Context, string) ([]*Greeter, error)
	listAllFn     func(context.Context) ([]*Greeter, error)
}
 
func (r *stubGreeterRepo) Save(ctx context.Context, a *Greeter) (*Greeter, error) {
	return r.saveFn(ctx, a)
}
func (r *stubGreeterRepo) Update(ctx context.Context, a *Greeter) (*Greeter, error) {
	return r.updateFn(ctx, a)
}
func (r *stubGreeterRepo) FindByID(ctx context.Context, a int64) (*Greeter, error) {
	return r.findByIDFn(ctx, a)
}
func (r *stubGreeterRepo) ListByHello(ctx context.Context, a string) ([]*Greeter, error) {
	return r.listByHelloFn(ctx, a)
}
func (r *stubGreeterRepo) ListAll(ctx context.Context) ([]*Greeter, error) {
	return r.listAllFn(ctx)
}